{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## PyTorch实现手写数字识别器\n",
    "> mnist是一个开源的手写数字数据集，借此实现一个简单的手写数字识别的网络\n",
    "\n",
    "相关参考\n",
    "\n",
    "* https://www.jb51.net/article/208404.htm\n",
    "* https://www.jb51.net/article/141074.htm\n",
    "* https://www.jb51.net/article/211872.htm"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 数据的处理\n",
    "> 使用pytorch自带的包进行数据的预处理\n",
    "\n",
    "直接将图片标准化到了-1到1的范围，标准化的原因就是因为如果某个数在数据中很大很大，就导致其权重较大，从而影响到其他数据\n",
    "\n",
    "本身我们的数据都是平等的，所以标准化后将数据分布到-1到1的范围，使得所有数据都不会有太大的权重导致网络出现巨大的波动\n",
    "\n",
    "trainloader现在是一个可迭代的对象，那么我们可以使用for循环进行遍历了，由于是使用yield返回的数据，为了节约内存\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "transform = transforms.Compose([\n",
    "  transforms.ToTensor(),\n",
    "  transforms.Normalize((0.5), (0.5))\n",
    "])\n",
    "# www.di.ens.fr/~lelarge/MNIST.tar.gz\n",
    "train_data = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True,num_workers=2)\n",
    "\n",
    "# 注释\n",
    "# transforms.Normalize 用于数据的标准化，具体实现( mean, std)\n",
    "# mean:均值 总和后除个数\n",
    "# std:方差 每个元素减去均值再平方再除个数\n",
    "# norm_data = (tensor - mean) / std"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 数据的查看"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def imshow(img):\n",
    "   img = img / 2 + 0.5 # unnormalize 传入的是Tensor\n",
    "   npimg = img.numpy()\n",
    "   plt.imshow(np.transpose(npimg, (1, 2, 0))) # 将通道维度置在第三个维度\n",
    "   plt.show()\n",
    "# torchvision.utils.make_grid 将图片进行拼接\n",
    "imshow(torchvision.utils.make_grid(iter(train_data_loader).next()[0]))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 构建网络\n",
    "1. 卷积层使用 torch.nn.Conv2d\n",
    "2. 激活层使用 torch.nn.ReLU\n",
    "3. 池化层使用 torch.nn.MaxPool2d\n",
    "4. 全连接层使用 torch.nn.Linear"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 示例模型 一\n",
    "class Net(nn.Module):\n",
    "  def __init__(self):\n",
    "    super(Net, self).__init__()\n",
    "    self.conv1 = nn.Conv2d(in_channels=1, out_channels=28, kernel_size=5)     # 14\n",
    "    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)                         # 无参数学习因此无需设置两个\n",
    "    self.conv2 = nn.Conv2d(in_channels=28, out_channels=28*2, kernel_size=5)  # 7\n",
    "    self.fc1 = nn.Linear(in_features=28*2*4*4, out_features=1024)\n",
    "    self.fc2 = nn.Linear(in_features=1024, out_features=10) # 最后输出 10 个分类\n",
    "  def forward(self, inputs):                  # Size([32, 1, 28, 28])\n",
    "    x = self.pool(F.relu(self.conv1(inputs))) # Size([32, 28, 12, 12])\n",
    "    x = self.pool(F.relu(self.conv2(x)))      # Size([32, 56, 4, 4])\n",
    "    x = x.view(inputs.size()[0],-1)           # Size([32, 896])\n",
    "    x = F.relu(self.fc1(x))                   # Size([32, 1024])\n",
    "    return self.fc2(x)                        # Size([32, 10])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 示例模型 二\n",
    "class LeNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet, self).__init__()\n",
    "        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),nn.ReLU(), nn.MaxPool2d(2, 2))\n",
    "        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),nn.MaxPool2d(2, 2))\n",
    "        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),nn.BatchNorm1d(120), nn.ReLU())\n",
    "        self.fc2 = nn.Sequential(\n",
    "            nn.Linear(120, 84),\n",
    "            nn.BatchNorm1d(84),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(84, 10))\n",
    "        # 最后的结果一定要变为 10，因为数字的选项是 0 ~ 9\n",
    "    def forward(self, x):                   # Size([32, 1, 28, 28])\n",
    "        x = self.conv1(x)                   # Size([32, 6, 15, 15])\n",
    "        x = self.conv2(x)                   # Size([32, 16, 5, 5])\n",
    "        x = x.view(x.size()[0], -1)         # Size([32, 400]) 对参数实现扁平化\n",
    "        x = self.fc1(x)                     # Size([32, 120])\n",
    "        x = self.fc2(x)                     # Size([32, 10])\n",
    "        return x"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "![](https://img.jbzj.com/file_images/article/202103/2021032611485914.gif)\n",
    "in_channels: 为输入通道数 彩色图片有3个通道 黑白有1个通道\n",
    "out_channels: 输出通道数\n",
    "kernel_size: 卷积核的大小\n",
    "stride: 卷积的步长\n",
    "padding: 外边距大小\n",
    "\n",
    "输出的size计算公式:\n",
    "h = (h - kernel_size + 2*padding)/stride + 1\n",
    "\n",
    "w = (w - kernel_size + 2*padding)/stride + 1\n",
    "\n",
    "MaxPool2d:是没有参数进行运算的"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 实例化网络优化器，并且使用GPU进行训练"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# net = Net()\n",
    "net = LeNet()\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "net.to(device)\n",
    "print(net)\n",
    "for name , parameter in net.named_parameters():\n",
    "    print(name, parameter) # 查看 默认参数\n",
    "# Net(\n",
    "#  (conv1): Conv2d(1, 28, kernel_size=(5, 5), stride=(1, 1))\n",
    "#  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
    "#  (conv2): Conv2d(28, 56, kernel_size=(5, 5), stride=(1, 1))\n",
    "#  (fc1): Linear(in_features=896, out_features=1024, bias=True)\n",
    "#  (fc2): Linear(in_features=1024, out_features=10, bias=True)\n",
    "# )"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 训练\n",
    "\n",
    "* 一般训练模型时，加上model.train() , 会正常使用 Batch Normalization 和 Dropout\n",
    "* 一般训练模型时，加上model.eval() , 不会正常使用 Batch Normalization 和 Dropout"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss() # 损失函数使用交叉熵\n",
    "opt = torch.optim.Adam(params=net.parameters(), lr=0.001) # 优化函数使用 Adam 自适应优化算法\n",
    "for epoch in range(10):\n",
    "  for images, labels in train_data_loader:\n",
    "    images = images.to(device)\n",
    "    labels = labels.to(device)\n",
    "    pre_label = net(images)\n",
    "    loss = criterion(pre_label, labels)\n",
    "    # loss = F.cross_entropy(input=pre_label, target=labels).mean()\n",
    "    pre_label = torch.argmax(pre_label, dim=1) # torch.argmax 计算最大数所在索引值\n",
    "    acc = (pre_label == labels).sum() / torch.tensor(labels.size()[0], dtype=torch.float32)\n",
    "    net.zero_grad()\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "  print(epoch, acc.detach().cpu().numpy(), loss.detach().cpu().numpy())\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# 预测\n",
    "\n",
    "test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=True,num_workers=2)\n",
    "images, labels = iter(test_loader).next()\n",
    "images = images.to(device)\n",
    "labels = labels.to(device)\n",
    "with torch.no_grad():\n",
    "  pre_label = net(images)\n",
    "  pre_label = torch.argmax(pre_label, dim=1)\n",
    "  acc = (pre_label==labels).sum() / torch.tensor(labels.size()[0], dtype=torch.float32)\n",
    "  print(acc)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}